#!/usr/bin/env python3
#
# Copyright (C) Hamish Coleman
# SPDX-License-Identifier: GPL-2.0-only
#
# Simple script to query the management interface of a running n3n edge node

import argparse
import base64
import collections
import http.client
import json
import os
import socket
import time
import urllib.parse
import urllib.request


class UnixHandler(urllib.request.BaseHandler):
    """urllib handler for ``unix://<session>/...`` URLs.

    The host part of the URL names a session.  The request is sent as plain
    HTTP over the unix domain socket at ``<basepath>/<session>/mgmt``.
    """

    def __init__(self, basepath):
        """Remember the directory holding the per-session subdirectories."""
        self.basepath = basepath

    def unix_open(self, req):
        """Send ``req`` over the session's unix socket and return the reply.

        Raises urllib.error.URLError when sending the request fails with a
        TimeoutError.  Other socket errors (eg. FileNotFoundError or
        ConnectionRefusedError from the connect) are passed through as-is.
        """
        pathname = os.path.join(self.basepath, req.host, "mgmt")

        h = http.client.HTTPConnection(req.host, timeout=req.timeout)

        # Merge the two header sets, with the unredirected ones winning
        headers = dict(req.unredirected_hdrs)
        headers.update({k: v for k, v in req.headers.items()
                        if k not in headers})
        headers = {name.title(): val for name, val in headers.items()}

        def unix_connect():
            h.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            # Since this replaces the stock connect(), nothing else will
            # apply the requested timeout to the socket.  The sentinel means
            # "no explicit timeout", in which case the socket default stays.
            if h.timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
                h.sock.settimeout(h.timeout)
            h.sock.connect(pathname)

        # hackitty hack - if only we could avoid setsockopt() in the default
        # connect, or if it didn't throw an error on unsupported ops
        h.connect = unix_connect

        try:
            try:
                h.request(req.get_method(), req.selector, req.data, headers)
            except TimeoutError as err:  # timeout error
                raise urllib.error.URLError(err) from err
            r = h.getresponse()
        finally:
            # Drop the connection's reference on success and on failure.  A
            # returned response keeps its own file object for the body.
            if h.sock:
                h.sock.close()
                h.sock = None

        r.url = req.get_full_url()
        return r


class JsonRPC:
    """Small JSON-RPC 2.0 client for the management API.

    Every request is POSTed to ``<url>/v1`` with urllib, so the url may use
    any scheme that the installed urllib opener can handle.
    """

    def __init__(self, url):
        """Set up a client for the given base url (which must not be None)"""
        self.debug = False

        if url is None:
            raise ValueError("need url")

        self.url = url + "/v1"
        self.timeout = 5
        self.id = 0
        self.username = "unused"
        self.password = None

    def _data(self, method, params):
        """Return the encoded request body, using the next request id"""
        self.id += 1
        data = {
            "jsonrpc": "2.0",
            "id": self.id,
            "method": method,
            "params": params
        }
        json_data = json.dumps(data).encode('utf8')
        if self.debug:
            print("data:", json_data)

        return json_data

    def _request_obj(self, method, params):
        """Build the urllib Request, adding Basic auth if a password is set"""
        data = self._data(method, params)

        req = urllib.request.Request(
            method="POST",
            url=self.url,
            headers={
                "Content-Type": "application/json",
            },
            data=data,
        )
        if self.password is not None:
            val = f"{self.username}:{self.password}".encode('utf8')
            encoded = base64.b64encode(val)
            req.add_header("Authorization", b"Basic " + encoded)
        return req

    class Unauthenticated(Exception):
        """Raised if the request needs authentication added"""
        pass

    class Overflow(Exception):
        """Raised if the daemon overflows its internal buffer"""
        def __init__(self, count):
            self.count = count
            super().__init__()

    def get_nopagination(self, method, params=None):
        """Makes the RPC request, with no automated pagination retry

        Returns the "result" member of the reply.

        Raises JsonRPC.Unauthenticated on a 401 status, JsonRPC.Overflow on
        a 507 status and ValueError for any other unusable reply.  A
        urllib.error.HTTPError with any other status is passed through.
        """

        req = self._request_obj(method, params)

        try:
            r = urllib.request.urlopen(req, timeout=self.timeout)
        except urllib.error.HTTPError as e:
            if e.code == 401:
                raise JsonRPC.Unauthenticated
            if e.code != 507:
                raise
            # For http(s) urls, urllib reports the overflow status as an
            # exception.  That exception is still a readable response, so it
            # can be handled just like a directly returned reply.
            r = e
            status = e.code
        else:
            status = r.status

        if status == 401:
            raise JsonRPC.Unauthenticated

        body = r.read()

        if self.debug:
            print("reply:", body)

        # Check the status before decoding, as an error body may not be JSON
        if status not in (200, 507):
            raise ValueError(f"urllib request got {status} {r.reason}")

        body = json.loads(body)

        if status == 507:
            raise JsonRPC.Overflow(body["error"]["data"]["count"])

        if "result" not in body:
            raise ValueError("jsonrpc error")

        # An explicit check, since an assert vanishes when run with -O
        if body.get('id') != str(self.id):
            raise ValueError(
                f"jsonrpc reply id {body.get('id')!r} does not match "
                f"request id {self.id}"
            )

        return body['result']

    def get(self, method, offset=None, limit=None, params=None):
        """Makes the RPC request, paginating automatically on overflow

        An empty params is treated the same as None.  If offset or limit are
        given they are added to the params.  The params passed in by the
        caller are never modified.

        Returns the result of the request.  When pagination was needed, this
        is the concatenation of the list returned for each page.

        Raises ValueError if pagination is needed but params is not a dict,
        NotImplementedError if a page is not a list and JsonRPC.Overflow if
        the daemon keeps overflowing or leaves no usable page size.
        """
        if params is not None and len(params) == 0:
            # This can happen with the args passed from CLI
            params = None
        elif isinstance(params, dict):
            # Work on a copy, so the caller never sees the pagination keys
            params = dict(params)

        # Populate the params if needed
        if offset is not None:
            if params is None:
                params = {}
            params["offset"] = offset
        if limit is not None:
            if params is None:
                params = {}
            params["limit"] = limit

        # Try once, possibly without pagination, to see if we overflow. And
        # if so, to detect a good size for the limit
        try:
            return self.get_nopagination(method, params)
        except JsonRPC.Overflow as e:
            count_max = e.count

        if params is None:
            params = {}

        if not isinstance(params, dict):
            # Since we overflowed, params must be compatible
            raise ValueError(f"Cannot use params={params} with autopagination")

        params["limit"] = count_max - 1
        params["offset"] = 0

        result = []

        while True:
            if params["limit"] < 1:
                # Asking for less than one item per page can never reach
                # the end of the list, so give up instead of looping
                raise JsonRPC.Overflow(count_max)

            try:
                partial = self.get_nopagination(method, params)
            except JsonRPC.Overflow as e:
                if count_max == e.count:
                    # If the limit didn't get smaller, there is a problem
                    # with the daemon, so we just bail out
                    raise

                # reduce our asking size and try again
                count_max = e.count
                params["limit"] = count_max
                continue

            if not isinstance(partial, list):
                # The current API only returns lists, but there could be
                # dicts in the future, catch this
                raise NotImplementedError

            result.extend(partial)
            params["offset"] = len(result)

            if len(partial) < params["limit"]:
                # We fetched less than we asked for, so we must be at the end
                # of the list
                return result


def str_table(rows, columns, orderby):
    """Given an array of dicts, do a simple table print

    rows is a list of dicts, columns is the list of keys to show (in
    display order) and orderby is either None or the key to sort the rows
    by (rows without that key sort as if they had the value 0).

    Each column is as wide as its widest value and the heading is truncated
    to fit that width.  A key missing from a row is shown as an empty cell.

    Returns the whole table as one string, each line ending in a newline.
    """
    if len(rows) == 0:
        # No data to show, be sure not to truncate the column headings
        widths = {col: len(col) for col in columns}
    else:
        widths = {
            col: max((len(str(row[col])) for row in rows if col in row),
                     default=0)
            for col in columns
        }

    # A width of zero would truncate the heading away completely
    widths = {col: max(width, 1) for col, width in widths.items()}

    lines = [
        "".join(
            "{:{}.{}} ".format(col, widths[col], widths[col])
            for col in columns
        )
    ]

    if orderby is not None:
        rows = sorted(rows, key=lambda row: row.get(orderby, 0))

    for row in rows:
        cells = []
        for col in columns:
            data = row.get(col, '')
            if not isinstance(data, (str, int, float)):
                # Values like None, lists or dicts do not accept a width in
                # their format spec, so show their string form instead
                data = str(data)
            cells.append("{:{}} ".format(data, widths[col]))
        lines.append("".join(cells))

    return "".join(line + "\n" for line in lines)


def num2timestr(seconds):
    """Convert a number of seconds into a human time

    Only the two most significant non-zero units out of days, hours,
    minutes and seconds are shown (eg. 3661 gives "1h1m").  Zero gives
    "now".  A negative number is shown as the same duration with a leading
    minus sign.
    """

    if seconds == 0:
        return "now"

    # Split the magnitude and add the sign back at the end.  Using divmod
    # directly on a negative number rounds down, which would show -5 as
    # "-1d23h".
    sign = "-" if seconds < 0 else ""
    remaining = abs(seconds)

    parts = []
    for unit, suffix in ((60*60*24, "d"), (60*60, "h"), (60, "m")):
        count, remaining = divmod(remaining, unit)
        if count and len(parts) < 2:
            parts.append(f"{count}{suffix}")

    if remaining and len(parts) < 2:
        parts.append(f"{remaining}s")

    return sign + "".join(parts)


def subcmd_show_supernodes(rpc, args):
    """Fetch 'get_supernodes' and return it formatted as a text table"""
    supernodes = rpc.get('get_supernodes')

    # Swap each last_seen value for the time elapsed since then, in a human
    # readable form
    timestamp = int(time.time())
    for entry in supernodes:
        elapsed = timestamp - entry["last_seen"]
        entry["last_seen"] = num2timestr(elapsed)

    return str_table(
        supernodes,
        ['version', 'current', 'macaddr', 'sockaddr', 'uptime', 'last_seen'],
        args.orderby,
    )


def subcmd_show_edges(rpc, args):
    """Fetch 'get_edges' and return it formatted as a text table"""
    edges = rpc.get('get_edges')

    # Swap each last_seen value for the time elapsed since then, in a human
    # readable form
    timestamp = int(time.time())
    for entry in edges:
        elapsed = timestamp - entry["last_seen"]
        entry["last_seen"] = num2timestr(elapsed)

    return str_table(
        edges,
        ['mode', 'ip4addr', 'macaddr', 'sockaddr', 'desc', 'last_seen'],
        args.orderby,
    )


def subcmd_mac(rpc, args):
    """Fetch 'get_mac' and return it formatted as a text table

    When no sort order was requested, args.orderby is set to "mac".
    """
    if args.orderby is None:
        args.orderby = "mac"

    entries = rpc.get('get_mac')

    # Add an "age" field holding the time elapsed since last_seen, in a
    # human readable form
    timestamp = int(time.time())
    for entry in entries:
        elapsed = timestamp - int(entry["last_seen"])
        entry["age"] = num2timestr(elapsed)

    return str_table(entries, ['_type', 'mac', 'dest', 'age'], args.orderby)


def subcmd_show_help(rpc, args):
    """Return the help text: the local commands, then the remote ones

    The local commands come from the subcmds table.  The remote ones are
    fetched with the 'help' request; if that fails with an HTTPError, a
    short error note is shown in their place.
    """
    parts = ['Commands with pretty-printed output:\n\n']
    parts.extend(
        "{:12} {}\n".format(name, cmd['help'])
        for name, cmd in subcmds.items()
    )

    parts.append("\n")
    parts.append("Possible remote commands:\n")
    parts.append("(those without a pretty-printer will pass-through)\n\n")

    try:
        rows = rpc.get('help')
        parts.extend(
            "{:15} {}\n".format(row['method'], row['desc'])
            for row in rows
        )
    except urllib.error.HTTPError:
        parts.append("Error requesting help data")

    return "".join(parts)


# The commands that have a local pretty-printer.  Each name maps to the
# function that implements it and to the one-line text shown by "help".
subcmds = {
    name: {'func': func, 'help': text}
    for name, func, text in (
        ('help', subcmd_show_help, 'Show available commands'),
        ('supernodes', subcmd_show_supernodes, 'Show the list of supernodes'),
        ('edges', subcmd_show_edges, 'Show the list of edges/peers'),
        ('mac', subcmd_mac, 'Show the mac address routing information'),
    )
}


def subcmd_default(rpc, args):
    """Pass the command straight through and return the reply as JSON text"""
    reply = rpc.get(args.cmd, params=args.args)
    return json.dumps(reply, indent=4, sort_keys=True)


def main():
    """Parse the command line, run the requested command and print the result

    The management url is taken from --sessionname if given, else from
    --mgmturl, and otherwise guessed when exactly one session has a mgmt
    socket in the run directory.  Known failures print a message and exit
    with status 1.
    """
    ap = argparse.ArgumentParser(
            description='Query the running local n3n edge')
    ap.add_argument('-s', '--sessionname', action='store',
                    help='Which session to use')
    ap.add_argument('-u', '--mgmturl', action='store',
                    help='Management API URL')
    ap.add_argument('-k', '--key', action='store',
                    help='Password for mgmt commands')
    ap.add_argument('-d', '--debug', action='store_true',
                    help='Also show raw internal data')
    ap.add_argument('--raw', action='store_true',
                    help='Force cmd to avoid any pretty printing')
    ap.add_argument('--orderby', action='store',
                    help='Hint to a pretty printer on how to sort')

    ap.add_argument('cmd', action='store',
                    help='Command to run (try "help" for list)')
    ap.add_argument('args', action='store', nargs="*",
                    help='Optional args for the command')

    args = ap.parse_args()

    # Sadly, the more than a decade old improvement to the expected location
    # of the run dir has not made it everywhere (eg OpenWrt), so we need to
    # probe for it
    if os.path.exists("/run"):
        rundir = "/run"
    elif os.path.exists("/var/run"):
        rundir = "/var/run"
    else:
        raise ValueError("No run directory found")

    # Teach urllib how to open the unix:// urls used below
    opener = urllib.request.build_opener(UnixHandler(f"{rundir}/n3n"))
    urllib.request.install_opener(opener)

    if args.sessionname is not None:
        args.mgmturl = "unix://" + args.sessionname

    if args.mgmturl is None:
        # try to guess the sessionname
        try:
            sessiondirs = os.listdir(f"{rundir}/n3n")
        except FileNotFoundError:
            sessiondirs = []

        sessions = [
            dirname for dirname in sessiondirs
            if os.path.exists(f"{rundir}/n3n/{dirname}/mgmt")
        ]

        if len(sessions) == 0:
            print("Error: no sessions found, please specify --mgmturl")
            raise SystemExit(1)
        elif len(sessions) == 1:
            # There is only one session, so choose that one
            args.mgmturl = "unix://" + sessions[0]
        else:
            print("Error: found multiple sessions:")
            print()
            print("\t", " ".join(sessions))
            print()
            print("please use --sessionname to choose one")
            raise SystemExit(1)

    if args.raw or (args.cmd not in subcmds):
        func = subcmd_default
    else:
        func = subcmds[args.cmd]['func']

    rpc = JsonRPC(args.mgmturl)
    rpc.debug = args.debug
    rpc.password = args.key

    # SystemExit is raised directly because the exit() builtin only exists
    # when the site module has been loaded
    try:
        result = func(rpc, args)
    except JsonRPC.Unauthenticated:
        print("This request requires an authentication key")
        raise SystemExit(1)
    except FileNotFoundError:
        print("Could not find unix socket (is the session running?)")
        raise SystemExit(1)
    except ConnectionRefusedError:
        print("Connection refused (is the session running?)")
        raise SystemExit(1)
    except socket.timeout as e:
        print(e)
        raise SystemExit(1)
    except urllib.error.URLError as e:
        # Failures that urllib has wrapped (eg. a timeout while sending the
        # request) arrive as this type instead of the plain socket errors
        print(f"Request failed: {e.reason}")
        raise SystemExit(1)

    print(result)


# Only run the command line tool when executed as a script, so that
# importing this file has no side effects
if __name__ == '__main__':
    main()
